# Copyright (c) 2023  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import datetime
import functools
import logging
import threading
import time

import colorlog

# Module-level dict; starts empty and is not populated anywhere in this section.
loggers = {}

# Level table used by Logger: maps each level name to its numeric logging
# level and the colorlog color name used when rendering that level.
# TRAIN (21) and EVAL (22) are custom levels placed between INFO and WARNING.
log_config = {
    level_name: {"level": level_no, "color": color_name}
    for level_name, level_no, color_name in (
        ("DEBUG", 10, "purple"),
        ("INFO", 20, "green"),
        ("TRAIN", 21, "cyan"),
        ("EVAL", 22, "blue"),
        ("WARNING", 30, "yellow"),
        ("ERROR", 40, "red"),
        ("CRITICAL", 50, "bold_red"),
    )
}


class Logger(object):
    """
    Default logger in PaddleFleetX.

    Wraps a named ``logging.Logger`` with a colored stream handler and exposes
    one callable attribute per entry of ``log_config``, in both upper and lower
    case (e.g. ``logger.INFO(msg)`` and ``logger.info(msg)``).

    Args:
        name(str) : Logger name, default is 'PaddleFleetX'
    """

    # Attribute set on the stream handler installed by this class so that a
    # later instance with the same name can find and reuse it.
    _HANDLER_TAG = "_paddlefleetx_stream_handler"

    def __init__(self, name: str = None):
        name = "PaddleFleetX" if not name else name
        self.logger = logging.getLogger(name)

        # Register every configured level name and bind a per-level shortcut
        # (upper- and lower-case) directly on the instance.
        for key, conf in log_config.items():
            logging.addLevelName(conf["level"], key)
            self.__dict__[key] = functools.partial(self.__call__, conf["level"])
            self.__dict__[key.lower()] = functools.partial(self.__call__, conf["level"])

        self.format = colorlog.ColoredFormatter(
            "%(log_color)s[%(asctime)-15s] [%(levelname)s]%(reset)s - %(message)s",
            log_colors={key: conf["color"] for key, conf in log_config.items()},
        )

        # logging.getLogger returns the same object for the same name, so
        # unconditionally adding a handler would make every message print once
        # per Logger instance. Reuse the handler a previous instance installed.
        self.handler = next(
            (h for h in self.logger.handlers if getattr(h, self._HANDLER_TAG, False)),
            None,
        )
        if self.handler is None:
            self.handler = logging.StreamHandler()
            setattr(self.handler, self._HANDLER_TAG, True)
            self.logger.addHandler(self.handler)
        self.handler.setFormatter(self.format)

        self.logLevel = "DEBUG"
        self.logger.setLevel(logging.DEBUG)
        # Keep records away from ancestor loggers' handlers.
        self.logger.propagate = False
        self._is_enable = True

    def disable(self):
        """Turn this logger off; subsequent calls emit nothing."""
        self._is_enable = False

    def enable(self):
        """Turn this logger back on."""
        self._is_enable = True

    @property
    def is_enable(self) -> bool:
        """Whether this logger currently emits messages."""
        return self._is_enable

    def __call__(self, log_level: int, msg: str):
        """
        Emit ``msg`` at the given numeric level unless the logger is disabled.

        Args:
            log_level(int): Numeric logging level (the "level" values of ``log_config``).
            msg(str): Message to be logged.
        """
        if not self.is_enable:
            return

        self.logger.log(log_level, msg)

    @contextlib.contextmanager
    def use_terminator(self, terminator: str):
        """
        Temporarily replace the stream handler's line terminator.

        The previous terminator is restored when the ``with`` block exits,
        including when it exits through an exception.

        Args:
            terminator(str): Terminator to use inside the ``with`` block.
        """
        old_terminator = self.handler.terminator
        self.handler.terminator = terminator
        try:
            yield
        finally:
            self.handler.terminator = old_terminator

    @contextlib.contextmanager
    def processing(self, msg: str, interval: float = 0.1):
        """
        Continuously print a progress bar with rotating special effects.

        A background thread logs ``"<msg>: <flag>"`` at INFO level while the
        ``with`` block runs. The thread is stopped and joined when the block
        exits, whether normally or through an exception.

        Args:
            msg(str): Message to be printed.
            interval(float): Rotation interval. Default to 0.1.
        """
        stop = threading.Event()

        def _printer():
            index = 0
            flags = ["\\", "|", "/", "-"]
            while not stop.is_set():
                flag = flags[index % len(flags)]
                # "\r" makes each frame overwrite the previous one.
                with self.use_terminator("\r"):
                    self.info("{}: {}".format(msg, flag))
                # Waiting on the event (rather than sleeping) lets the thread
                # wake up immediately once the context exits.
                stop.wait(interval)
                index += 1

        t = threading.Thread(target=_printer)
        t.start()
        try:
            yield
        finally:
            # Always stop the spinner; otherwise an exception in the body
            # would leave a non-daemon thread running forever.
            stop.set()
            t.join()


# Shared module-level Logger instance (default name); used by the helper
# functions defined below in this module.
logger = Logger()


def advertise():
    """
    Log a framed banner crediting PaddlePaddle and pointing to the project site.

    The banner is emitted as one INFO message through the module-level
    ``logger``. Each text row is centered inside a ``==`` frame whose inner
    width is the longest row plus 6. Layout (widths are illustrative):

    =======================================================
    ==      PaddleFleetX is powered by PaddlePaddle !    ==
    =======================================================
    ==                                                   ==
    ==  For more info please go to the following ...     ==
    ==                                                   ==
    ==    https://github.com/PaddlePaddle/PaddleFleetX   ==
    =======================================================
    """
    headline = "PaddleFleetX is powered by PaddlePaddle !"
    pointer = "For more info please go to the following website."
    url = "https://github.com/PaddlePaddle/PaddleFleetX"
    inner_width = max(len(headline), len(pointer), len(url)) + 6

    def framed(text):
        # Center the text in the inner width and wrap it in the "==" side bars.
        return "==" + text.center(inner_width) + "=="

    rule = "=" * (inner_width + 4)
    spacer = framed("")
    rows = [
        rule,
        framed(headline),
        rule,
        spacer,
        framed(pointer),
        spacer,
        framed(url),
        rule,
    ]

    # Leading and trailing newlines keep the banner off the log prefix line.
    logger.info("\n" + "\n".join(rows) + "\n")


from .device import synchronize  # noqa: E402


def get_timestamp():
    """
    Return the current ``time.time()`` value after calling ``synchronize()``.

    If ``synchronize()`` returns a falsy value, a warning is logged first; the
    timestamp is returned either way.
    """
    synced = synchronize()
    if not synced:
        logger.warning("Device synchronizing failed, which may result uncorrect time")
    return time.time()


def convert_timestamp_to_data(timeStamp):
    """
    Format a number of seconds as a ``datetime.timedelta`` string.

    The value is truncated to an integer number of seconds first, so the
    result looks like ``"H:MM:SS"`` (prefixed with ``"N day(s), "`` for
    spans of a day or more).
    """
    whole_seconds = int(timeStamp)
    span = datetime.timedelta(seconds=whole_seconds)
    return str(span)
